#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c) 2020 Huawei Technologies Co.,Ltd.
#
# openGauss is licensed under Mulan PSL v2.
# You can use this software according to the terms
# and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
#          http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS,
# WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# ----------------------------------------------------------------------------
# Description  : gs_sshexkey is a utility to create SSH trust among nodes in
# a cluster.
#############################################################################

import sys
import warnings

# Suppress DeprecationWarning output for the whole run.
warnings.simplefilter('ignore', DeprecationWarning)
# Make the ../lib directory next to this script importable.
sys.path.append(sys.path[0] + "/../lib")
import time
import os
import subprocess
import pwd
import grp
import socket
import getpass
import shutil
# Ensure gspylib/clib (located next to this file) is at the front of
# LD_LIBRARY_PATH. The script re-execs itself when the variable is unset,
# and again when the variable does not start with that directory.
# NOTE(review): presumably a re-exec is needed because the dynamic loader
# only reads LD_LIBRARY_PATH at process start, so updating os.environ alone
# would not affect libraries loaded by this process.
package_path = os.path.dirname(os.path.realpath(__file__))
ld_path = package_path + "/gspylib/clib"
if 'LD_LIBRARY_PATH' not in os.environ:
    os.environ['LD_LIBRARY_PATH'] = ld_path
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)
if not os.environ.get('LD_LIBRARY_PATH').startswith(ld_path):
    os.environ['LD_LIBRARY_PATH'] = \
        ld_path + ":" + os.environ['LD_LIBRARY_PATH']
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)

from gspylib.common.GaussLog import GaussLog
from gspylib.common.ErrorCode import ErrorCode
from gspylib.threads.parallelTool import parallelTool
from gspylib.common.Common import DefaultValue, ClusterCommand
from gspylib.common.ParameterParsecheck import Parameter
from gspylib.os.gsfile import g_file
from gspylib.os.gsOSlib import g_OSlib

# Must run before paramiko is imported (presumably it prepares the
# environment paramiko needs; see DefaultValue.doConfigForParamiko).
DefaultValue.doConfigForParamiko()
import paramiko

# Marker appended to every /etc/hosts line written by this tool. Lines that
# carry it are filtered out again before the file is rewritten.
HOSTS_MAPPING_FLAG = "#Gauss OM IP Hosts Mapping"
# "IP  hostname  HOSTS_MAPPING_FLAG" lines built by
# GaussCreateTrust.writeRemoteHosts() and read by writeRemoteHostName().
ipHostInfo = ""
# the tmp path (assigned in GaussCreateTrust.run())
tmp_files = ""
# tmp file name
TMP_TRUST_FILE = "step_preinstall_file.dat"


class PrintOnScreen():
    """
    Console stand-in for GaussLog, used when no log file (-l) is given.

    log() prints the message, debug() discards it, and logExit() prints
    the message and then exits with status 1. log() and debug() also
    accept the optional step-flag argument that GaussLog.log/debug receive
    elsewhere in this file, so this object can be swapped in for GaussLog.
    """

    def __init__(self):
        """
        function : Constructor
        input: NA
        output: NA
        """
        pass

    def log(self, msg, stepFlag=""):
        """
        function : print log
        input: msg: str
               stepFlag: optional step marker, accepted for call
                         compatibility with GaussLog.log and ignored here
        output: NA
        """
        print(msg)

    def debug(self, msg, stepFlag=""):
        """
        function : debug messages are not shown on screen
        input: msg: debug message string
               stepFlag: optional step marker, accepted for call
                         compatibility with GaussLog.debug and ignored here
        output: NA
        """
        pass

    def logExit(self, msg):
        """
        function : print log and exit with status 1
        input: msg: str
        output: NA
        """
        print(msg)
        sys.exit(1)


class GaussCreateTrust():
    """
    class about create trust for user
    """
    # Messages consumed by logOut() by index.
    # Entries 0 ("addStep") and 1 ("constant") are step flags. logOut()
    # passes them as the second argument to self.logger.log().
    # The remaining entries are start/success message pairs, one pair per
    # trust-creation stage: (2, 3), (4, 5), (6, 7), ... (16, 17).
    log_list = ["addStep",
                "constant",
                "Checking network information.",
                "Successfully checked network information.",
                "Creating the local key file.",
                "Successfully created the local key files.",
                "Appending local ID to authorized_keys.",
                "Successfully appended local ID to authorized_keys.",
                "Updating the known_hosts file.",
                "Successfully updated the known_hosts file.",
                "Appending authorized_key on the remote node.",
                "Successfully appended authorized_key on all remote node.",
                "Checking common authentication file content.",
                "Successfully checked common authentication content.",
                "Distributing SSH trust file to all node.",
                "Successfully distributed SSH trust file to all node.",
                "Verifying SSH trust on all hosts.",
                "Successfully verified SSH trust on all hosts.",
                ]

    def __init__(self):
        """
        function : Constructor
        input: NA
        output: NA
        """
        self.logger = None
        self.hostFile = ""
        self.hostList = []
        self.passwd = []
        self.logFile = ""
        self.localHost = ""
        self.flag = False
        self.logger = None
        self.localID = ""
        self.user = pwd.getpwuid(os.getuid()).pw_name
        self.group = grp.getgrgid(os.getgid()).gr_name
        self.incorrectPasswdInfo = ""
        self.failedToAppendInfo = ""
        self.homeDir = os.path.expanduser("~" + self.user)
        self.sshDir = "%s/.ssh" % self.homeDir
        self.authorized_keys_fname = '%s/.ssh/authorized_keys' % self.homeDir
        self.known_hosts_fname = '%s/.ssh/known_hosts' % self.homeDir
        self.id_rsa_fname = '%s/.ssh/id_rsa' % self.homeDir
        self.id_rsa_pub_fname = self.id_rsa_fname + '.pub'
        self.skipHostnameSet = False
        self.isKeyboardPassword = False
        self.nodeduplicate = False

    def usage(self):
        """
gs_sshexkey is a utility to create SSH trust among nodes in a cluster.

Usage:
  gs_sshexkey -? | --help
  gs_sshexkey -V | --version
  gs_sshexkey -f HOSTFILE [-l LOGFILE] [--skip-hostname-set]
                          

General options:
  -f                          Host file containing the IP address of nodes.
  -h                          Host ip list. Separate multiple nodes with commas(,).
  -l                          Path of log file.
      --skip-hostname-set     Whether to skip hostname setting.
                              (The default value is set.)
  -?, --help                  Show help information for this utility,
                              and exit the command line mode.
  -V, --version               Show version information.
        """
        print(self.usage.__doc__)

    def parseCommandLine(self):
        """
        function: Parse the gs_sshexkey command line. Prints usage and exits
                  with status 0 when help is requested; otherwise copies
                  every recognised option onto this object.
        input : NA
        output: NA
        """
        options = Parameter().ParameterCommandLine("sshexkey")
        if "helpFlag" in options:
            self.usage()
            sys.exit(0)

        # (key produced by Parameter, attribute of this object it sets)
        optionToAttr = (
            ("hostfile", "hostFile"),
            ("nodename", "hostList"),
            ("logFile", "logFile"),
            ("skipHostnameSet", "skipHostnameSet"),
            ("noDeduplicate", "nodeduplicate"),
        )
        for optionKey, attrName in optionToAttr:
            if optionKey in options:
                setattr(self, attrName, options.get(optionKey))

    def checkParameter(self):
        """
        function: Validate the parsed options: either an -h list of valid
                  IPs, or an existing absolute -f host file that yields at
                  least one host. Also checks that -l is an absolute path,
                  and asks for the user password when none is known yet.
        input : NA
        output: NA
        """
        if len(self.hostList) != 0:
            # -h given: every entry must be a valid IP address
            for host_ip in self.hostList:
                if DefaultValue.isIpValid(host_ip):
                    continue
                GaussLog.exitWithError(
                    ErrorCode.GAUSS_500["GAUSS_50000"] % host_ip)
        else:
            # no -h: the -f host file is required, must exist and be
            # absolute, and must produce a non-empty host list
            if self.hostFile == "":
                self.usage()
                GaussLog.exitWithError(
                    ErrorCode.GAUSS_500["GAUSS_50001"] % 'f' + ".")
            if not os.path.exists(self.hostFile):
                GaussLog.exitWithError(
                    ErrorCode.GAUSS_502["GAUSS_50201"] % self.hostFile)
            if not os.path.isabs(self.hostFile):
                GaussLog.exitWithError(
                    ErrorCode.GAUSS_502["GAUSS_50213"] % self.hostFile)

            self.readHostFile()

            if self.hostList == []:
                GaussLog.exitWithError(
                    ErrorCode.GAUSS_500["GAUSS_50004"] % 'f'
                    + " It cannot be empty.")

        if self.logFile != "" and not os.path.isabs(self.logFile):
            GaussLog.exitWithError(
                ErrorCode.GAUSS_502["GAUSS_50213"] % self.logFile)

        if not self.passwd:
            self.passwd = self.getUserPasswd()
            self.isKeyboardPassword = True

    def logOut(self, log_index1, log_index2):
        """
        function:logout
        :param log_index1: index of the log
        :param log_index2: indec of the log
        :return:
        """
        if (self.logFile != ""):
            if (not os.path.exists(tmp_files)):
                self.logger.log(GaussCreateTrust.log_list[log_index1],
                                GaussCreateTrust.log_list[log_index2])
            else:
                self.logger.log(GaussCreateTrust.log_list[log_index1])
        else:
            self.logger.log(GaussCreateTrust.log_list[log_index1])

    def readHostFile(self):
        """
        function: read host file to hostList
        input : NA
        output: NA
        """
        inValidIp = []
        try:
            with open(self.hostFile, "r") as f:
                for readLine in f:
                    hostname = readLine.strip().split("\n")[0]
                    if hostname != "" and hostname not in self.hostList:
                        if not DefaultValue.isIpValid(hostname):
                            inValidIp.append(hostname)
                            continue
                        self.hostList.append(hostname)
            if len(inValidIp) > 0:
                GaussLog.exitWithError(ErrorCode.GAUSS_506["GAUSS_50603"]
                                       + "The IP list is:%s." % inValidIp)
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_502["GAUSS_50204"] % "host file"
                            + " Error: \n%s" % str(e))

    def getAllHostsName(self, ip):
        """
        function:
          Log in to one node over SSH (port 22) with the first password in
          self.passwd and read its hostname. parallelGetHosts() runs this
          for every node through parallelTool.parallelExecute.
        precondition:
          1.User's password is correct on each node
        postcondition:
           The SSH transport is closed on every exit path.
        input: ip
        output: Dictionary ipHostname: {ip: hostname} on success, or
                {"Node[ip]": message} when running "cd" on the node
                produced any output
        raises: Exception when the transport cannot be created or the
                login is rejected
        hideninfo:NA
        """
        ipHostname = {}
        try:
            ssh = paramiko.Transport((ip, 22))
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_512["GAUSS_51220"] % ip
                            + " Error: \n%s" % str(e))
        try:
            try:
                ssh.connect(username=self.user, password=self.passwd[0])
            except Exception:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % ip)

            # A bare "cd" prints nothing. Any output therefore comes from
            # the remote login scripts and would pollute command results.
            check_channel = ssh.open_session()
            check_channel.exec_command("cd")
            env_msg = check_channel.recv_stderr(9999).decode()
            while True:
                channel_read = check_channel.recv(9999).decode().strip()
                if len(channel_read) == 0:
                    break
                env_msg += channel_read
            if env_msg != "":
                ipHostname["Node[%s]" % ip] = "Output: [" + env_msg \
                    + " ] print by /etc/profile or ~/.bashrc, please check it."
                return ipHostname

            channel = ssh.open_session()
            channel.exec_command("hostname")
            ipHostname[ip] = channel.recv(9999).decode().strip()
            return ipHostname
        finally:
            # Close the transport on the early return above and on any
            # channel error, not just on the success path.
            ssh.close()

    def verifyPasswd(self, ssh, pswd=None):
        try:
            ssh.connect(username=self.user, password=pswd)
            return True
        except Exception:
            ssh.close()
            return False

    def parallelGetHosts(self, sshIps):
        """
        function: Run getAllHostsName() for every IP in parallel and merge
                  the per-node results.
        input : sshIps: node IPs
        output: Dictionary {ip: hostname}
        raises: Exception (GAUSS_51808) listing every node whose login
                environment printed output (result keys containing "Node")
        """
        hostnames = {}
        errorText = ""
        nodeResults = parallelTool.parallelExecute(self.getAllHostsName,
                                                   sshIps)
        for nodeResult in nodeResults:
            for entryKey, entryValue in nodeResult.items():
                if "Node" in entryKey:
                    errorText += str(nodeResult)
                else:
                    hostnames[entryKey] = entryValue
        if errorText:
            raise Exception(ErrorCode.GAUSS_518["GAUSS_51808"] % errorText)
        return hostnames

    def serialGetHosts(self, sshIps):
        """
        function:
          Get the hostname of each node, one node at a time. Every password
          in self.passwd is tried. If none works and the passwords were
          typed in (self.isKeyboardPassword), the user is prompted up to
          3 times for that node, and an accepted password is appended to
          self.passwd.
        input: sshIps: list of node IPs
        output: Dictionary {ip: hostname}
        raises: Exception when:
                - a node rejects every password;
                - the remote login environment prints output;
                - some IPs could not be reached (raised after all other
                  nodes were handled).
        """
        serialResult = {}
        invalidIP = ""
        boolInvalidIp = False
        for sshIp in sshIps:
            isPasswdOK = False
            for pswd in self.passwd:
                try:
                    ssh = paramiko.Transport((sshIp, 22))
                except Exception as e:
                    self.logger.debug(str(e))
                    invalidIP += "Incorrect IP address: %s.\n" % sshIp
                    boolInvalidIp = True
                    break
                # verifyPasswd() closes the transport when the login fails
                isPasswdOK = self.verifyPasswd(ssh, pswd)
                if (isPasswdOK):
                    break

            if (boolInvalidIp):
                # unreachable node: remember it and report all at the end
                boolInvalidIp = False
                continue

            if (not isPasswdOK and self.isKeyboardPassword):
                GaussLog.printMessage("Please enter password for current"
                                      " user[%s] on the node[%s]."
                                      % (self.user, sshIp))
                # Try entering the password 3 times interactively
                for _ in range(3):
                    KeyboardPassword = getpass.getpass()
                    DefaultValue.checkPasswordVaild(KeyboardPassword)
                    ssh = paramiko.Transport((sshIp, 22))
                    isPasswdOK = self.verifyPasswd(ssh, KeyboardPassword)
                    if (isPasswdOK):
                        self.passwd.append(KeyboardPassword)
                        break
            # if isKeyboardPassword is true, 3 times after the password is
            # also wrong to throw an unusual exit
            if (not isPasswdOK):
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % sshIp)

            # `ssh` is authenticated from here on. Close it on every path,
            # including the exception raised for a noisy login environment.
            try:
                # A bare "cd" prints nothing. Any output comes from the
                # remote /etc/profile or ~/.bashrc.
                cmd = "cd"
                check_channel = ssh.open_session()
                check_channel.exec_command(cmd)
                check_result = check_channel.recv_stderr(9999).decode()
                while True:
                    channel_read = check_channel.recv(9999).decode()
                    if (len(channel_read) != 0):
                        check_result += str(channel_read)
                    else:
                        break

                if (check_result != ""):
                    raise Exception(ErrorCode.GAUSS_518["GAUSS_51808"]
                                    % check_result + "Please check %s node"
                                                     " /etc/profile or ~/.bashrc"
                                    % sshIp)
                cmd = "hostname"
                channel = ssh.open_session()
                channel.exec_command(cmd)
                while True:
                    hostname = channel.recv(9999).decode().strip()
                    if (len(hostname) != 0):
                        serialResult[sshIp] = hostname
                    else:
                        break
            finally:
                ssh.close()

        if (invalidIP):
            raise Exception(
                ErrorCode.GAUSS_511["GAUSS_51101"] % invalidIP.rstrip("\n"))

        return serialResult

    def getAllHosts(self, sshIps):
        """
        function:
          Connect to all nodes ,then get all hostaname
        precondition:
          1.User's password is correct on each node
        postcondition:
           NA
        input: sshIps,username,passwd
        output:Dictionary ipHostname,key is IP  and value is hostname
        hideninfo:NA
        """
        if (self.logFile != ""):
            if (not os.path.exists(tmp_files)):
                self.logger.debug("Get hostnames for all nodes.", "addStep")
            else:
                self.logger.debug("Get hostnames for all nodes.")
        if (len(self.passwd) == 0):
            self.isKeyboardPassword = True
            GaussLog.printMessage("Please enter password for current user[%s]."
                                  % self.user)
            passwd = getpass.getpass()
            self.passwd.append(passwd)

        if (len(self.passwd) == 1):
            try:
                result = self.parallelGetHosts(sshIps)
            except Exception as e:
                if (self.isKeyboardPassword and str(e).startswith(
                        "[GAUSS-50306] : The password of")):
                    GaussLog.printMessage(
                        "Notice :The password of some nodes is incorrect.")
                    result = self.serialGetHosts(sshIps)
                else:
                    raise Exception(str(e))
        else:
            result = self.serialGetHosts(sshIps)
        if (self.logFile != ""):
            if (not os.path.exists(tmp_files)):
                self.logger.debug("Successfully get hostnames for all nodes.",
                                  "constant")
            else:
                self.logger.debug("Successfully get hostnames for all nodes.")
        return result

    def writeLocalHosts(self, result):
        """
        function:
         Rewrite the local /etc/hosts (only when running as root):
         1. remove every line containing HOSTS_MAPPING_FLAG and every empty
            line;
         2. add one "IP  hostname  HOSTS_MAPPING_FLAG" line per entry of
            `result`.
         Unless self.nodeduplicate is set, IPs that are already the first
         field of a remaining line are not added again.
        precondition:
          NA
        postcondition:
           NA
        input: Dictionary result,key is IP and value is hostname
        output: NA
        raises: Exception when /etc/hosts is missing or rewriting it fails
        hideninfo:NA
        """
        if (self.logFile != ""):
            if (not os.path.exists(tmp_files)):
                self.logger.debug(
                    "Write local hostname and Ip into /etc/hosts.", "addStep")
            else:
                self.logger.debug(
                    "Write local hostname and Ip into /etc/hosts.")
        hostIPInfo = ""
        if (os.getuid() == 0):
            # temporary copy, created relative to the current directory
            tmpHostIpName = "./tmp_hostsiphostname_%d" % os.getpid()
            # Check if /etc/hosts exists.
            if (not os.path.exists("/etc/hosts")):
                raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"] +
                                " Error: \nThe /etc/hosts does not exist.")
            # Current content minus old mapping lines and empty lines.
            # NOTE(review): `status` is never checked. getstatusoutput merges
            # stderr into `output`, so a grep error message would be written
            # into /etc/hosts below. Confirm whether this needs handling.
            cmd = "grep -v '" + HOSTS_MAPPING_FLAG + "' /etc/hosts| grep -v '^$'"
            (status, output) = subprocess.getstatusoutput(cmd)
            try:
                # Replace /etc/hosts with the filtered content via a temp file
                g_file.createFile(tmpHostIpName)
                g_file.changeMode(DefaultValue.KEY_FILE_MODE, tmpHostIpName)
                g_file.writeFile(tmpHostIpName, [output])
                shutil.copyfile(tmpHostIpName, '/etc/hosts')
                g_file.removeFile(tmpHostIpName)
            except Exception as e:
                if os.path.exists(tmpHostIpName):
                    g_file.removeFile(tmpHostIpName)
                raise Exception(str(e))
            if not self.nodeduplicate:
                # first whitespace-separated field (the IP) of each kept line
                ipCompare = []
                for line in output.split("\n"):
                    if line:
                        ipCompare.append(line.replace("\t", " ").strip().split(' ')[0])
                # keep only the IPs not already present in /etc/hosts
                tmpResult = {}
                for s_key in list(result.keys()):
                    if s_key not in ipCompare:
                        tmpResult[s_key] = result[s_key]
                for (key, value) in tmpResult.items():
                    hostIPInfo += '%s  %s  %s\n' % (key, value, HOSTS_MAPPING_FLAG)
            else:
                for (key, value) in result.items():
                    hostIPInfo += '%s  %s  %s\n' % (key, value, HOSTS_MAPPING_FLAG)
            # drop the trailing newline of the last mapping line
            hostIPInfo = hostIPInfo[:-1]
            ipInfoList = [hostIPInfo]
            # NOTE(review): presumably g_file.writeFile appends here;
            # otherwise the filtered content copied above would be lost.
            # This call also runs when hostIPInfo is empty. Confirm.
            g_file.writeFile("/etc/hosts", ipInfoList)
        if (self.logFile != ""):
            if (not os.path.exists(tmp_files)):
                self.logger.debug(
                    "Successfully write local hostname and Ip into "
                    "/etc/hosts.",
                    "constant")
            else:
                self.logger.debug(
                    "Successfully write local hostname and Ip into "
                    "/etc/hosts.")

    def writeRemoteHostName(self, ip):
        """
        function:
         Rewrite /etc/hosts on one remote node:
         1. keep every line that does not match " #Gauss.* IP Hosts Mapping"
            and is not empty;
         2. append the global ipHostInfo mapping lines. Unless
            self.nodeduplicate is set, only lines whose IP is not already
            the first field of a kept line are appended.
         writeRemoteHosts() runs this per node through
         parallelTool.parallelExecute.
        precondition:
          self.passwd[0] is valid for the current user on `ip`;
          ipHostInfo holds "IP  hostname  HOSTS_MAPPING_FLAG" lines
        postcondition:
           the SSH transport to `ip` is closed
        input: ip: address or hostname of the remote node
        output: (ok, {ip: [stderr messages]}); ok is False when any remote
                command wrote to stderr
        raises: Exception when the transport cannot be created or the
                login fails
        hideninfo:NA
        """
        import shlex

        def readAll(recvFunc):
            # recv()/recv_stderr() return at most the requested number of
            # bytes and an empty result only at EOF. Keep reading so a
            # large /etc/hosts is not truncated before being written back.
            chunks = []
            while True:
                data = recvFunc(9999)
                if not data:
                    break
                chunks.append(data)
            return b"".join(chunks).decode().strip()

        writeResult = []
        result = {}
        tmpHostIpName = "./tmp_hostsiphostname_%d_%s" % (os.getpid(), ip)
        username = pwd.getpwuid(os.getuid()).pw_name
        global ipHostInfo
        try:
            ssh = paramiko.Transport((ip, 22))
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51107"]
                            + " Error: \n%s" % str(e))
        try:
            try:
                ssh.connect(username=username, password=self.passwd[0])
            except Exception as e:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50317"]
                                + " Error: \n%s" % str(e))
            # current content without old mapping lines and empty lines
            cmd = "grep -v '%s' %s  | grep -v '^$'" \
                  % (" #Gauss.* IP Hosts Mapping", '/etc/hosts')
            channel = ssh.open_session()
            channel.exec_command(cmd)
            ipHosts = readAll(channel.recv)
            errInfo = readAll(channel.recv_stderr)
            # Write it back through a temp file. The content is a single
            # shell-quoted printf argument, so quotes, "$", backticks or
            # backslashes in /etc/hosts are written literally.
            cmd = "printf '%%s\\n' %s > %s ; cp %s %s && rm -rf %s" \
                  % (shlex.quote(ipHosts), tmpHostIpName, tmpHostIpName,
                     '/etc/hosts', tmpHostIpName)
            channel = ssh.open_session()
            channel.exec_command(cmd)
            ipHosts1 = readAll(channel.recv)
            errInfo1 = readAll(channel.recv_stderr)
            if errInfo + errInfo1:
                writeResult.append(errInfo + errInfo1)
            elif not ipHosts1:
                if not self.nodeduplicate:
                    # IPs (first field) already present on the remote node
                    ipCompare = set()
                    for line in ipHosts.split("\n"):
                        if line:
                            ipCompare.add(
                                line.replace("\t", " ").strip().split(' ')[0])
                    tmpIpHostInfo = ""
                    for info in ipHostInfo.split("\n"):
                        if info.split(' ')[0] not in ipCompare:
                            tmpIpHostInfo += info + "\n"
                    cmd = "echo '%s' >> /etc/hosts" % tmpIpHostInfo
                else:
                    cmd = "echo '%s' >> /etc/hosts" % ipHostInfo
                channel = ssh.open_session()
                channel.exec_command(cmd)
                errInfo = readAll(channel.recv_stderr)
                if errInfo:
                    writeResult.append(errInfo)
            channel.close()
        finally:
            # the transport was previously never closed on any path
            ssh.close()
        result[ip] = writeResult
        return (len(writeResult) == 0, result)

    def writeRemoteHosts(self, result, username, rootPasswd):
        """
        function:
         On every remote node (entries whose hostname equals
         self.localHost are skipped), rewrite /etc/hosts:
         1. drop lines matching " #Gauss.* IP Hosts Mapping" and empty
            lines;
         2. append "IP  hostname  HOSTS_MAPPING_FLAG" lines built from
            `result`. Unless self.nodeduplicate is set, IPs already present
            are skipped.
         Only done when running as root (uid 0). With exactly one password
         the nodes are handled in parallel by writeRemoteHostName();
         otherwise serially, trying each password in turn.
        precondition:
          NA
        postcondition:
           NA
        input: Dictionary result,key is IP and value is hostname
               username: user to log in as
               rootPasswd: list of candidate passwords
        output: NA
        raises: Exception (GAUSS_51221) when any node reported errors
        hideninfo:NA
        """
        if (self.logFile != ""):
            if (not os.path.exists(tmp_files)):
                self.logger.debug(
                    "Write remote hostname and Ip into /etc/hosts.", "addStep")
            else:
                self.logger.debug(
                    "Write remote hostname and Ip into /etc/hosts.")
        # module-level text read by writeRemoteHostName() worker threads
        global ipHostInfo
        boolInvalidIp = False
        ipHostInfo = ""
        if (os.getuid() == 0):
            writeResult = []
            # relative path, resolved in the remote command's working dir
            tmpHostIpName = "./tmp_hostsiphostname_%d" % os.getpid()

            if (len(rootPasswd) == 1):
                # Build the shared mapping text and keep one IP per remote
                # hostname. result1 maps hostname -> IP.
                result1 = {}
                for (key, value) in result.items():
                    ipHostInfo += '%s  %s  %s\n' % (key, value,
                                                    HOSTS_MAPPING_FLAG)
                    if (value != self.localHost):
                        if (not value in result1.keys()):
                            result1[value] = key

                # NOTE: these are the HOSTNAMES (result1 keys), which
                # writeRemoteHostName() connects to. They must resolve;
                # run() calls writeLocalHosts() before this method.
                sshIps = result1.keys()
                ipHostInfo = ipHostInfo[:-1]
                if (sshIps):
                    ipRemoteHostname = parallelTool.parallelExecute(
                        self.writeRemoteHostName, sshIps)
                    errorMsg = ""
                    # each element is writeRemoteHostName's
                    # (ok, {node: [errors]}) tuple
                    for (key, value) in ipRemoteHostname:
                        if (not key):
                            errorMsg = errorMsg + '\n' + str(value)
                    if (errorMsg != ""):
                        raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"]
                                        + " Error: %s" % errorMsg)
            else:
                for (key, value) in result.items():
                    if (value == self.localHost):
                        continue
                    # Try each password until one logs in.
                    # NOTE(review): transports from failed attempts are not
                    # closed. If every password fails, `ssh` is left
                    # unauthenticated and open_session() below is called on
                    # it anyway (presumably raising). Confirm the intent.
                    for pswd in rootPasswd:
                        try:
                            ssh = paramiko.Transport((key, 22))
                        except Exception as e:
                            self.logger.debug(str(e))
                            boolInvalidIp = True
                            break
                        try:
                            ssh.connect(username=username, password=pswd)
                            break
                        except Exception as e:
                            self.logger.debug(str(e))
                            continue
                    # unreachable node: skipped silently (debug log only)
                    if (boolInvalidIp):
                        boolInvalidIp = False
                        continue
                    # NOTE(review): recv(9999) returns at most 9999 bytes,
                    # so a larger /etc/hosts would be truncated when it is
                    # written back below. Also, content containing '"', '$'
                    # or backticks is interpreted by the remote shell inside
                    # echo "...".
                    cmd = "grep -v '%s' %s | grep -v '^$'" % (
                        " #Gauss.* IP Hosts Mapping", '/etc/hosts')
                    channel = ssh.open_session()
                    channel.exec_command(cmd)
                    ipHosts = channel.recv(9999).decode().strip()
                    errInfo = channel.recv_stderr(9999).decode().strip()
                    cmd = "echo \"%s\" > %s ; cp %s %s && rm -rf %s" % (
                        ipHosts, tmpHostIpName, tmpHostIpName,
                        '/etc/hosts', tmpHostIpName)
                    channel = ssh.open_session()
                    channel.exec_command(cmd)
                    ipHosts1 = channel.recv(9999).decode().strip()
                    errInfo1 = channel.recv_stderr(9999).decode().strip()

                    if (errInfo + errInfo1):
                        writeResult.append(errInfo + errInfo1)
                    else:
                        if (not ipHosts1):
                            # rebuild the per-node mapping (the global
                            # ipHostInfo is overwritten for every node)
                            ipHostInfo = ""
                            if not self.nodeduplicate:
                                # first field (IP) of each kept line
                                ipCompare = []
                                for line in ipHosts.split("\n"):
                                    if line:
                                        ipCompare.append(line.replace("\t", " ").strip().split(' ')[0])
                                for (key1, value1) in result.items():
                                    if key1 not in ipCompare:
                                        ipHostInfo += '%s  %s  %s\n' % (
                                            key1, value1, HOSTS_MAPPING_FLAG)
                            else:
                                for (key1, value1) in result.items():
                                    ipHostInfo += '%s  %s  %s\n' % (
                                        key1, value1, HOSTS_MAPPING_FLAG)
                            ipHostInfo = ipHostInfo[:-1]
                            cmd = "echo '%s' >> /etc/hosts" % ipHostInfo
                            channel = ssh.open_session()
                            channel.exec_command(cmd)
                            errInfo = channel.recv_stderr(
                                9999).decode().strip()
                            if (errInfo):
                                writeResult.append(errInfo)

                    # NOTE(review): only the channel is closed here; the
                    # transport `ssh` itself is never closed.
                    if channel:
                        channel.close()

                if (len(writeResult) > 0):
                    raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"]
                                    + " Error: \n%s" % writeResult)
        if (self.logFile != ""):
            if (not os.path.exists(tmp_files)):
                self.logger.debug(
                    "Successfully write remote hostname and Ip into "
                    "/etc/hosts.",
                    "constant")
            else:
                self.logger.debug(
                    "Successfully write remote hostname and Ip into "
                    "/etc/hosts.")

    def initLogger(self):
        """
        function: Choose the logger. Uses a GaussLog writing to self.logFile
                  when -l was given, otherwise a PrintOnScreen console
                  stand-in.
        input : NA
        output: NA
        """
        useLogFile = self.logFile != ""
        self.logger = (GaussLog(self.logFile, "gs_sshexkey") if useLogFile
                       else PrintOnScreen())

    def checkNetworkInfo(self):
        """
        function: Check that every node in self.hostList is reachable from
                  the local node. DefaultValue.checkIsPing presumably
                  returns the failing IPs. Exits via self.logger.logExit
                  when any node fails or the check itself raises.
        input : NA
        output: NA
        """
        self.logOut(2, 0)
        try:
            unreachable = DefaultValue.checkIsPing(self.hostList)
            if unreachable:
                self.logger.logExit(ErrorCode.GAUSS_506["GAUSS_50600"]
                                    + "The IP list is:%s." % unreachable)
            else:
                self.logger.log("All nodes in the network are Normal.")
        except Exception as e:
            self.logger.logExit(str(e))
        self.logOut(3, 1)

    def run(self):
        """
        function: Entry point. Steps, in order:
                  1. parse and validate the options;
                  2. resolve every node's hostname and check the network;
                  3. optionally sync /etc/hosts on all nodes;
                  4. run each SSH-trust step.
                  A failure in any trust step is reported through
                  self.logger.logExit.
        input : NA
        output: NA
        """
        global tmp_files

        self.parseCommandLine()
        self.checkParameter()
        self.localHost = socket.gethostname()
        self.initLogger()

        tmp_files = "/tmp/%s" % TMP_TRUST_FILE
        # The total step count is announced only when logging to a file
        # and the step-tracking file does not exist yet.
        if self.logFile != "" and not os.path.exists(tmp_files):
            totalSteps = ClusterCommand.countTotalSteps(
                "gs_sshexkey", "", self.skipHostnameSet)
            self.logger.debug(
                "gs_sshexkey execution takes %s steps in total" % totalSteps)

        ipToHostname = self.getAllHosts(list(self.hostList))
        self.checkNetworkInfo()

        if not self.skipHostnameSet:
            self.writeLocalHosts(ipToHostname)
            self.writeRemoteHosts(ipToHostname, self.user, self.passwd)

        self.logger.log("Creating SSH trust.")
        try:
            self.localID = self.createPublicPrivateKeyFile()
            self.addLocalAuthorized()
            self.updateKnow_hostsFile(ipToHostname)
            self.addRemoteAuthorization()
            self.determinePublicAuthorityFile()
            self.synchronizationLicenseFile()
            self.verifyTrust()
            self.logger.log("Successfully created SSH trust.")
        except Exception as e:
            self.logger.logExit(str(e))

    def createPublicPrivateKeyFile(self):
        """
        function: create the local RSA key pair when the public key file
                  (self.id_rsa_pub_fname) does not exist yet, then return
                  the public key line.
        input : NA
        output: first line of the public key file, whitespace-stripped
        raises: Exception (GAUSS_51108) when ssh-keygen/chmod fails or the
                public key file cannot be read
        """
        self.logOut(4, 0)

        if not os.path.exists(self.id_rsa_pub_fname):
            # Generate into self.id_rsa_fname instead of a hard-coded
            # ~/.ssh/id_rsa, so the key written is the one chmod-ed below.
            # ssh-keygen writes the public half next to it with a ".pub"
            # suffix; presumably that is self.id_rsa_pub_fname.
            # TODO confirm against __init__.
            cmd = 'ssh-keygen -t rsa -N \"\" -f %s < /dev/null' \
                  % self.id_rsa_fname
            cmd += " && chmod %s %s %s" % (DefaultValue.KEY_FILE_MODE,
                                           self.id_rsa_fname,
                                           self.id_rsa_pub_fname)
            (status, output) = subprocess.getstatusoutput(cmd)
            if status != 0:
                self.logger.log("The cmd is %s " % cmd)
                raise Exception(ErrorCode.GAUSS_511["GAUSS_51108"]
                                + " Error:\n%s" % output)
        try:
            with open(self.id_rsa_pub_fname, 'r') as f:
                return f.readline().strip()
        except IOError as e:
            self.logger.debug(str(e))
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51108"]
                            + " Unable to read the generated file: "
                            + self.id_rsa_pub_fname)
        finally:
            # Logged for both a successful and a failed read, as before.
            self.logOut(5, 1)

    def addLocalAuthorized(self):
        """
        function: append the local public key (self.localID) to the local
                  authorized_keys file unless an identical line is already
                  present, then set the file mode to KEY_FILE_MODE.
        input : NA
        output: NA
        """
        self.logOut(6, 0)
        g_file.createFileInSafeMode(self.authorized_keys_fname)
        # Read the file before appending. A file opened in 'a+' mode is
        # positioned at EOF, so iterating it never sees existing lines.
        with open(self.authorized_keys_fname, 'r') as f:
            content = f.read()
        existing = (line.strip() for line in content.splitlines())
        if self.localID not in existing:
            with open(self.authorized_keys_fname, 'a') as f:
                # Do not glue the new key onto a final unterminated line.
                if content and not content.endswith('\n'):
                    f.write('\n')
                f.write(self.localID)
                f.write('\n')
        # Always log the step and enforce the mode, even when the key was
        # already present.
        self.logOut(7, 1)
        g_file.changeMode(DefaultValue.KEY_FILE_MODE,
                          self.authorized_keys_fname)

    def checkAuthentication(self, hostname):
        """
        function: run a trivial `true` command on hostname over ssh to test
                  whether password-less access works.
        input : hostname
        output: (True, hostname) when ssh exits with status 0, otherwise
                (False, hostname); failures are only logged at debug level
        """
        sshCmd = 'ssh -n %s %s true' % (DefaultValue.SSH_OPTION, hostname)
        status, output = subprocess.getstatusoutput(sshCmd)
        if status == 0:
            return (True, hostname)
        self.logger.debug("The cmd is %s " % sshCmd)
        self.logger.debug(
            "Failed to check authentication. Hostname:%s. Error: \n%s"
            % (hostname, output))
        return (False, hostname)

    def updateKnow_hostsFile(self, result):
        """
        function: append the RSA host key of every target to the known_hosts
                  file via ssh-keyscan, then verify password-less ssh to the
                  local host.
        input : result - mapping whose values are scanned in addition to
                self.hostList
        output: NA
        raises: Exception (GAUSS_51400) when a keyscan/chmod command fails,
                Exception (GAUSS_51100) when the local host check fails
        """
        self.logOut(8, 0)
        targets = list(self.hostList) + list(result.values())
        for target in targets:
            scanCmd = ('ssh-keyscan -t rsa %s >> %s '
                       % (target, self.known_hosts_fname)
                       + "&& chmod %s %s" % (DefaultValue.KEY_FILE_MODE,
                                             self.known_hosts_fname))
            status, output = subprocess.getstatusoutput(scanCmd)
            if status != 0:
                raise Exception(ErrorCode.GAUSS_514["GAUSS_51400"] % scanCmd
                                + " Error:\n%s" % output)
        authOk, _ = self.checkAuthentication(self.localHost)
        if not authOk:
            raise Exception(
                ErrorCode.GAUSS_511["GAUSS_51100"] % self.localHost)
        self.logOut(9, 1)

    def tryParamikoConnect(self, hostname, client, pswd=None, silence=False):
        """
        function: try to open a password-authenticated connection with the
                  given paramiko client (agent and key files are disabled).
        input : hostname, client, pswd, silence (suppress debug logging)
        output: True when connected; False on authentication failure (the
                client is closed in that case)
        raises: Exception carrying the original error text for any other
                connection error (client closed, original chained as cause)
        """
        try:
            client.connect(hostname, password=pswd, allow_agent=False,
                           look_for_keys=False)
            return True
        except paramiko.AuthenticationException as e:
            if not silence:
                self.logger.debug("Incorrect password. Node: %s." % hostname
                                  + " Error:\n%s" % str(e))
            client.close()
            return False
        except Exception as e:
            if not silence:
                self.logger.debug('[SSHException %s] %s' % (hostname, str(e)))
            client.close()
            # Keep the original exception as __cause__ for diagnostics.
            raise Exception(str(e)) from e

    def addRemoteAuthorization(self):
        """
        function: run sendRemoteAuthorization for every host in
                  self.hostList in parallel, then exit through the logger
                  if any host reported a wrong password or a failed append
        input : NA
        output: NA
        """
        self.logOut(10, 0)
        try:
            parallelTool.parallelExecute(self.sendRemoteAuthorization,
                                         self.hostList)
            # Check the wrong-password report before the append report.
            # Each attribute is read only when it is reached.
            for infoAttr in ("incorrectPasswdInfo", "failedToAppendInfo"):
                info = getattr(self, infoAttr)
                if info != "":
                    self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51101"]
                                        % (info.rstrip("\n")))
        except Exception as e:
            self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51111"]
                                + " Error:%s." % str(e))
        self.logOut(11, 1)

    def sendRemoteAuthorization(self, hostname):
        """
        function: connect to a remote host with paramiko and append the
                  local public key to its authorized_keys file.
                  Each candidate password in self.passwd is tried in order.
                  The remote .ssh directory is created and given the key
                  file modes first. The append is retried up to 3 times.
                  Failures are recorded in self.incorrectPasswdInfo /
                  self.failedToAppendInfo rather than raised.
                  Nothing is done for the local host.
        input : hostname
        output: NA
        """
        if hostname == self.localHost:
            return
        p = None
        try:
            # ssh remote connection to the other node
            p = paramiko.SSHClient()
            p.load_system_host_keys()
            ok = False
            for pswd in self.passwd:
                if self.tryParamikoConnect(hostname, p, pswd, silence=True):
                    ok = True
                    break
            if not ok:
                self.incorrectPasswdInfo += "Without this node[%s] of " \
                                            "the correct password.\n" % \
                                            hostname
                return
            # Create .ssh directory and ensure content meets permission
            # requirements for password-less SSH. Wait for the command to
            # finish so the append below cannot race the mkdir.
            cmd = ('mkdir -p .ssh; ' + "chown -R %s:%s %s; " %
                   (self.user, self.group, self.sshDir) + 'chmod %s .ssh; '
                   % DefaultValue.KEY_DIRECTORY_MODE
                   + 'touch .ssh/authorized_keys; '
                   + 'touch .ssh/known_hosts; '
                   + 'chmod %s .ssh/auth* .ssh/id* .ssh/known_hosts; '
                   % DefaultValue.KEY_FILE_MODE)
            self._runRemoteCommand(p, cmd)

            # Append the ID to authorized_keys: one attempt plus up to three
            # retries, sleeping 0, 2 and 4 seconds before the retries.
            appendCmd = 'echo \"%s\" >> .ssh/authorized_keys && echo ok ok ok' \
                        % self.localID
            output = self._runRemoteCommand(p, appendCmd)
            retry = 0
            while output.find("ok ok ok") < 0 and retry < 3:
                time.sleep(retry * 2)
                output = self._runRemoteCommand(p, appendCmd)
                retry += 1

            if output.find("ok ok ok") < 0:
                self.failedToAppendInfo += "...send to %s\nFailed to " \
                                           "append local ID to " \
                                           "authorized_keys on remote " \
                                           "node %s.\n" % (
                                               hostname, hostname)
                return
            self.logger.debug(
                "Send to %s\nSuccessfully appended authorized_key on "
                "remote node %s." % (hostname, hostname))
        finally:
            if p:
                p.close()

    def _runRemoteCommand(self, client, cmd):
        """
        function: run cmd through client.exec_command, wait for its stdout
                  to reach EOF and close all three streams
        input : client (connected paramiko SSHClient), cmd
        output: the command's decoded standard output
        """
        (cin, cout, cerr) = client.exec_command(cmd)
        try:
            cin.close()
            return cout.read().decode(errors='replace')
        finally:
            # Close every stream of every attempt, not just the last one.
            cout.close()
            cerr.close()

    def determinePublicAuthorityFile(self):
        '''
        function: determine common authentication file content
        input : NA
        output: NA
        '''
        self.logOut(12, 0)
        # eliminate duplicates in known_hosts file
        try:
            tab = self.readKnownHosts()
            self.writeKnownHosts(tab)
        except IOError as e:
            self.logger.logExit(ErrorCode.GAUSS_502["GAUSS_50230"]
                                % "known hosts file" + " Error:\n%s" % str(e))

        # eliminate duploicates in authorized_keys file
        try:
            tab = self.readAuthorizedKeys()
            self.writeAuthorizedKeys(tab)
        except IOError as e:
            self.logger.logExit(ErrorCode.GAUSS_502["GAUSS_50230"]
                                % "authorized keys file" + " Error:\n%s"
                                % str(e))
        self.logOut(13, 1)

    def addRemoteID(self, tab, line):
        """
        function: record an authorized_keys line in tab, keyed by its third
                  whitespace-separated field, so a later line with the same
                  field replaces an earlier one
        input : tab (dict, updated in place), line
        output: True if the line has exactly 3 fields and does not start
                with '#', otherwise False (tab untouched)
        """
        IDKey = line.strip().split()
        if not (len(IDKey) == 3 and line[0] != '#'):
            return False
        # Store the line newline-terminated. A replacing duplicate keeps the
        # earlier dict position, so an unterminated last line could otherwise
        # be written in the middle of the file and merge with the next entry.
        tab[IDKey[2]] = line if line.endswith('\n') else line + '\n'
        return True

    def readAuthorizedKeys(self, tab=None, keysFile=None):
        """
        function: read authorized keys
        input : tab, keysFile
        output: tab
        """
        if not keysFile:
            keysFile = self.authorized_keys_fname
        if not tab:
            tab = {}
        with open(keysFile, 'r') as f:
            for line in f:
                self.addRemoteID(tab, line)
        return tab

    def writeAuthorizedKeys(self, tab, keysFile=None):
        """
        function: overwrite an authorized_keys file with the values of tab,
                  written in dict iteration order
        input : tab, keysFile (self.authorized_keys_fname when empty)
        output: NA
        """
        target = keysFile or self.authorized_keys_fname
        with open(target, 'w') as out:
            out.writelines(tab.values())

    def addKnownHost(self, tab, line):
        """
        function: record a known_hosts line in tab, keyed by its first
                  whitespace-separated field, so a later line with the same
                  field replaces an earlier one
        input : tab (dict, updated in place), line
        output: True if the line has exactly 3 fields and does not start
                with '#', otherwise False (tab untouched)
        """
        key = line.strip().split()
        if not (len(key) == 3 and line[0] != '#'):
            return False
        # Store the line newline-terminated. A replacing duplicate keeps the
        # earlier dict position, so an unterminated last line could otherwise
        # be written in the middle of the file and merge with the next entry.
        tab[key[0]] = line if line.endswith('\n') else line + '\n'
        return True

    def readKnownHosts(self, tab=None, hostsFile=None):
        """
        function: read known host
        input : tab, hostsFile
        output: tab
        """
        if not hostsFile:
            hostsFile = self.known_hosts_fname
        if not tab:
            tab = {}
        with open(hostsFile, 'r') as f:
            for line in f:
                self.addKnownHost(tab, line)
        return tab

    def writeKnownHosts(self, tab, hostsFile=None):
        """
        function: overwrite a known_hosts file with the values of tab,
                  written in dict iteration order
        input : tab, hostsFile (self.known_hosts_fname when empty)
        output: NA
        """
        target = hostsFile or self.known_hosts_fname
        with open(target, 'w') as out:
            out.writelines(tab.values())

    def sendTrustFile(self, hostname):
        '''
        function: copy the local authorized_keys, known_hosts and key pair
                  files into ~/.ssh on hostname with a non-interactive scp
        input : hostname
        output: NA
        raises: Exception (GAUSS_50223) when scp exits non-zero
        '''
        fileArgs = "%s %s %s %s" % (self.authorized_keys_fname,
                                    self.known_hosts_fname,
                                    self.id_rsa_fname,
                                    self.id_rsa_pub_fname)
        scpCmd = ('scp -q -o "BatchMode yes" -o "NumberOfPasswordPrompts 0" '
                  '%s %s:.ssh/' % (fileArgs, hostname))
        status, output = subprocess.getstatusoutput(scpCmd)
        if status == 0:
            return
        raise Exception(ErrorCode.GAUSS_502["GAUSS_50223"]
                        % "the authentication" + " Node:%s. Error:\n%s."
                        % (hostname, output) + "The cmd is %s " % scpCmd)

    def synchronizationLicenseFile(self):
        '''
        function: push the local trust files to every host in self.hostList
                  in parallel via sendTrustFile; any error exits through the
                  logger
        input : NA
        output: NA
        '''
        self.logOut(14, 0)
        try:
            parallelTool.parallelExecute(self.sendTrustFile, self.hostList)
        except Exception as err:
            self.logger.logExit(str(err))
        self.logOut(15, 1)

    def verifyTrust(self):
        """
        function: check password-less ssh to every host in self.hostList in
                  parallel and exit through the logger listing the hosts
                  that failed
        input : NA
        output: NA
        """
        self.logOut(16, 0)
        try:
            outcomes = parallelTool.parallelExecute(self.checkAuthentication,
                                                    self.hostList)
            failedHosts = [host for (passed, host) in outcomes if not passed]
            if failedHosts:
                raise Exception(ErrorCode.GAUSS_511["GAUSS_51100"]
                                % ",".join(failedHosts).lstrip(','))
        except Exception as e:
            self.logger.logExit(str(e))
        self.logOut(17, 1)

    def getUserPasswd(self):
        """
        function: read the current user's password, prompting interactively
                  when stdin is a terminal, otherwise taking one line from
                  stdin (trailing newline removed)
        input: NA
        output: a one-element list holding the password
        Exits with an error when the password is empty.
        """
        if sys.stdin.isatty():
            GaussLog.printMessage(
                "Please enter password for current user[%s]." % self.user)
            password = getpass.getpass()
        else:
            password = sys.stdin.readline().strip('\n')

        # Test the password itself. The list built for the return value
        # always has one element, so checking the list never caught an
        # empty password.
        if not password:
            GaussLog.exitWithError("Password should not be empty")

        return [password]


if __name__ == '__main__':
    # Entry point: build the trust creator and run it. Exceptions that do
    # not already carry a GAUSS error code are tagged with GAUSS-50100.
    try:
        GaussCreateTrust().run()
    except Exception as e:
        message = str(e)
        if not message.startswith("[GAUSS-"):
            message = "[GAUSS-50100]:" + message
        GaussLog.exitWithError(message)

    sys.exit(0)
